# mocov1
# Scratch demo: torch.sum(x, dim=k) adds values along axis k and removes
# that axis from the result's shape.
import torch

# b has shape (3, 2, 3): three identical 2x3 slices.
b = torch.tensor([
    [
        [1, 2, 3],
        [4, 5, 6],
    ],
    [
        [1, 2, 3],
        [4, 5, 6],
    ],
    [
        [1, 2, 3],
        [4, 5, 6],
    ],
])

# dim=0: add the three 2x3 slices element-wise -> shape (2, 3).
print(torch.sum(b, dim=0))
# tensor([[ 3,  6,  9],
#         [12, 15, 18]])

# dim=1: add the two rows within each slice -> shape (3, 3).
print(torch.sum(b, dim=1))
# tensor([[5, 7, 9],
#         [5, 7, 9],
#         [5, 7, 9]])

# dim=2: add the three entries of each row -> shape (3, 2).
print(torch.sum(b, dim=2))
# tensor([[ 6, 15],
#         [ 6, 15],
#         [ 6, 15]])